import ray
import numpy as np

from agentic_system.environments.env_package.mirrorapi.mirrorapi.mirrorapi_agent_env import MirrorAPIAgentSiteEnv


# -----------------------------------------------------------------------------
# Ray remote worker actor -----------------------------------------------------
# -----------------------------------------------------------------------------

@ray.remote(num_cpus=0.25)
class MirrorAPIWorker:
    """Ray remote actor that replaces the worker function.

    Each actor hosts its own *MirrorAPIAgentSiteEnv* instance.
    """

    def __init__(self, seed, max_steps):
        # NOTE: ``seed`` is accepted for API compatibility but is not
        # currently consumed by the underlying environment.
        self.max_steps = max_steps
        self.env = MirrorAPIAgentSiteEnv(max_steps=self.max_steps)
    
    def step(self, action, invoke, valid, his_text_ob):
        """Execute a step in the environment."""

        obs, reward, done, info = self.env.step(action, invoke, valid, his_text_ob)

        # Keep the raw environment reward around for logging.
        info['task_score'] = reward

        # An episode counts as won when it terminates with a full score
        # (reward >= 1.0); the reward itself is passed through unchanged.
        info['won'] = bool(done and reward >= 1.0)

        return obs, reward, done, info
    
    def reset(self, idx):
        """Reset the environment with the given task index."""
        obs, info = self.env.reset(index=idx)
        return obs, info
    
    def render(self, mode_for_render):
        """Render the environment (required by the wrapper's ``render``)."""
        return self.env.render(mode=mode_for_render)
    
    def initialize_and_get_goals(self):
        """Return the goals hosted by the underlying environment server."""
        return self.env.server.goals

    def initialize_with_goals(self, goals):
        """Initialise the environment with a given goal set (currently a no-op)."""
        pass

    def close(self):
        """Close the environment (currently a no-op)."""
        return


# -----------------------------------------------------------------------------
# Vectorised Ray environment --------------------------------------------------
# -----------------------------------------------------------------------------

class MirrorAPIMultiProcessEnv:
    """A vectorised, Ray-based wrapper around *MirrorAPIAgentSiteEnv*.

    :py:meth:`step` and :py:meth:`reset` fan calls out to one Ray actor per
    environment and gather the batched results, so downstream RL code can
    drive all environments with plain Python lists.
    """
    def __init__(
        self,
        seed: int = 0,
        env_num: int = 1,
        group_n: int = 1,
        is_train: bool = True,
        env_kwargs: dict | None = None,
    ) -> None:

        # Initialize Ray if not already initialized
        if not ray.is_initialized():
            ray.init()

        self.group_n = group_n
        self.env_num = env_num
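        # One Ray actor per (task, group member): ``env_num`` distinct tasks,
        # each replicated ``group_n`` times (see ``reset``).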
        self.num_processes = env_num * group_n
        self.is_train = is_train
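        # Hard-coded dataset split: task indices [0, 12800) are used for
        # training and [12800, 24800) for evaluation.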
        if self.is_train:
            self.start_idx = 0
            self.data_len = 12800
        else:
            self.start_idx = 12800
            self.data_len = 12000

        self._rng = np.random.RandomState(seed)

        self._env_kwargs = env_kwargs if env_kwargs is not None else {}

        # -------------------------- Ray actors setup --------------------------
        self._workers = []
        max_steps = int(self._env_kwargs.get("max_steps", 6))

        for i in range(self.num_processes):
            # Workers within the same group share a seed (``i // self.group_n``).
            worker = MirrorAPIWorker.remote(seed + (i // self.group_n), max_steps)
            self._workers.append(worker)


    # ------------------------------------------------------------------
    # Base API ----------------------------------------------------------
    # ------------------------------------------------------------------

    def step(self, actions: list[str], invokes: list[str], valids: list[int], his_text_obs):
        if not his_text_obs:
            his_text_obs = [" "] * len(actions)
        if len(actions) != self.num_processes or len(invokes) != self.num_processes:
            raise ValueError(
                f'Expected {self.num_processes} actions and invokes, '
                f'got {len(actions)} actions and {len(invokes)} invokes',
            )

        # Send step commands to all workers
        futures = []
        for worker, action, invoke, valid, his_text_ob in zip(self._workers, actions, invokes, valids, his_text_obs):
            future = worker.step.remote(action, invoke, valid, his_text_ob)
            futures.append(future)

        # Collect results; log the offending batch before re-raising so that
        # shape mismatches are easy to diagnose.
        try:
            results = ray.get(futures)
        except Exception:
            print(f"#### actions {actions} len {len(actions)}")
            print(f"#### invokes {invokes} len {len(invokes)}")
            print(f"#### valids {valids} len {len(valids)}")
            print(f"#### workers len {len(self._workers)}")
            raise

        obs_list, reward_list, done_list, info_list = [], [], [], []
        for obs, reward, done, info in results:
            obs_list.append(obs)
            reward_list.append(reward)
            done_list.append(done)
            info_list.append(info)

        return obs_list, reward_list, done_list, info_list

    def reset(self):
        # TODO: restore the proper index values.
        # Sample ``env_num`` distinct task indices from the current split, then
        # replicate each index ``group_n`` times so that every member of a
        # group starts from the same task.
        idx = self._rng.choice(self.data_len, size=self.env_num, replace=False)
        idx = idx + self.start_idx
        idx = np.repeat(idx, self.group_n).tolist()

        # Send reset commands to all workers
        futures = []
        for worker, i in zip(self._workers, idx):
            future = worker.reset.remote(i)
            futures.append(future)

        # Collect results
        results = ray.get(futures)
        obs_list, info_list = [], []
        for obs, info in results:
            obs_list.append(obs)
            info_list.append(info)
            # print(f"ENV #### {info}")

        return obs_list, info_list

    # ------------------------------------------------------------------
    # Convenience helpers ----------------------------------------------
    # ------------------------------------------------------------------

    def render(self, mode: str = 'text', env_idx: int | None = None):
        if env_idx is not None:
            future = self._workers[env_idx].render.remote(mode)
            return ray.get(future)

        futures = []
        for worker in self._workers:
            future = worker.render.remote(mode)
            futures.append(future)
        
        return ray.get(futures)

    # ------------------------------------------------------------------
    # Clean-up ----------------------------------------------------------
    # ------------------------------------------------------------------

    def close(self):
        # Idempotent: repeated calls (e.g. from ``__del__``) are no-ops.
        if getattr(self, '_closed', False):
            return

        # Ask every worker to close its environment...
        close_futures = [worker.close.remote() for worker in self._workers]
        ray.get(close_futures)

        # ...then kill the Ray actors.
        for worker in self._workers:
            ray.kill(worker)

        self._closed = True

    def __del__(self):  # noqa: D401
        try:
            self.close()
        except Exception:
            # Ray may already be shut down during interpreter teardown.
            pass


# -----------------------------------------------------------------------------
# Factory helper --------------------------------------------------------------
# -----------------------------------------------------------------------------

def build_mirrorapi_envs(
    seed: int = 0,
    env_num: int = 1,
    group_n: int = 1,
    is_train: bool = True,
    env_kwargs: dict | None = None,
):
    """Mirror *build_sokoban_envs* so higher-level code can swap seamlessly."""
    return MirrorAPIMultiProcessEnv(
        seed=seed,
        env_num=env_num,
        group_n=group_n,
        is_train=is_train,
        env_kwargs=env_kwargs,
    )
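

# -----------------------------------------------------------------------------
# Usage sketch ----------------------------------------------------------------
# -----------------------------------------------------------------------------
# Minimal, illustrative smoke test. The "noop" action, empty invoke payloads,
# and all-valid flags below are hypothetical placeholders (a real rollout would
# come from a policy); it assumes Ray can start locally and task data exists.

if __name__ == "__main__":
    envs = build_mirrorapi_envs(seed=0, env_num=2, group_n=2,
                                env_kwargs={"max_steps": 6})
    obs_list, info_list = envs.reset()  # one observation per worker actor

    n = envs.num_processes
    actions = ["noop"] * n  # hypothetical placeholder actions
    invokes = [""] * n      # hypothetical placeholder invoke payloads
    valids = [1] * n        # mark every action as valid
    obs_list, rewards, dones, infos = envs.step(actions, invokes, valids, None)

    envs.close()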